"""ViT"""
from typing import Callable, Optional

import numpy as np

import mindspore as ms
from mindspore import Parameter, Tensor, nn, ops
from mindspore.common.initializer import HeUniform, TruncatedNormal, initializer

from .helpers import load_pretrained
from .layers.compatibility import Dropout
from .layers.drop_path import DropPath
from .layers.mlp import Mlp
from .layers.patch_dropout import PatchDropout
from .layers.patch_embed import PatchEmbed
from .layers.pos_embed import resample_abs_pos_embed
from .registry import register_model

__all__ = [
    "VisionTransformer",
    "vit_b_16_224",
    "vit_b_16_384",
    "vit_l_16_224",  # with pretrained weights
    "vit_l_16_384",
    "vit_b_32_224",  # with pretrained weights
    "vit_b_32_384",
    "vit_l_32_224",  # with pretrained weights
]


def _cfg(url="", **kwargs):
    return {
        "url": url,
        "num_classes": 1000,
        "input_size": (3, 224, 224),
        "first_conv": "patch_embed.proj",
        "classifier": "head.classifier",
        **kwargs,
    }


default_cfgs = {
    "vit_b_16_224": _cfg(url=""),
    "vit_b_16_384": _cfg(
        url="", input_size=(3, 384, 384)
    ),
    "vit_l_16_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-97d0fdbc.ckpt"),
    "vit_l_16_384": _cfg(
        url="", input_size=(3, 384, 384)
    ),
    "vit_b_32_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-f50866e8.ckpt"),
    "vit_b_32_384": _cfg(
        url="", input_size=(3, 384, 384)
    ),
    "vit_l_32_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-b80441df.ckpt"),
}


# TODO: Flash Attention
class Attention(nn.Cell):
    """
    Multi-head self-attention layer. The input of shape (B, N, C) is projected to queries, keys and
    values, split across attention heads, attended over, and projected back to (B, N, C).

    Args:
        dim (int): The dimension of input features.
        num_heads (int): The number of attention heads. Default: 8.
        qkv_bias (bool): Specifies whether the qkv projection uses a bias vector. Default: True.
        qk_norm (bool): Specifies whether to normalize q and k with ``norm_layer``. Default: False.
        attn_drop (float): The drop rate applied to the attention map, in the range [0, 1). Default: 0.0.
        proj_drop (float): The drop rate applied to the output projection, in the range [0, 1). Default: 0.0.
        norm_layer (nn.Cell): Normalization layer used for q/k normalization. Default: nn.LayerNorm.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> attn = Attention(768, 12)
    """
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Cell = nn.LayerNorm,
    ):
        super(Attention, self).__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = Tensor(self.head_dim ** -0.5)

        self.qkv = nn.Dense(dim, dim * 3, has_bias=qkv_bias)
        self.q_norm = norm_layer((self.head_dim,)) if qk_norm else nn.Identity()
        self.k_norm = norm_layer((self.head_dim,)) if qk_norm else nn.Identity()

        self.attn_drop = Dropout(attn_drop)
        self.proj = nn.Dense(dim, dim)
        self.proj_drop = Dropout(proj_drop)

        self.mul = ops.Mul()
        self.reshape = ops.Reshape()
        self.transpose = ops.Transpose()
        self.unstack = ops.Unstack(axis=0)
        self.attn_matmul_v = ops.BatchMatMul()
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)

    def construct(self, x):
        b, n, c = x.shape
        qkv = self.qkv(x)
        qkv = self.reshape(qkv, (b, n, 3, self.num_heads, self.head_dim))
        qkv = self.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = self.unstack(qkv)
        q, k = self.q_norm(q), self.k_norm(k)

        attn = self.q_matmul_k(q, k)
        attn = self.mul(attn, self.scale)

        attn = attn.astype(ms.float32)
        attn = ops.softmax(attn, axis=-1)
        attn = self.attn_drop(attn)

        out = self.attn_matmul_v(attn, v)
        out = self.transpose(out, (0, 2, 1, 3))
        out = self.reshape(out, (b, n, c))
        out = self.proj(out)
        out = self.proj_drop(out)

        return out
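
# A minimal shape-flow sketch for Attention (assumes MindSpore and NumPy, imported above; the
# numbers are illustrative):
#   attn = Attention(dim=768, num_heads=12)
#   x = ms.Tensor(np.ones((2, 197, 768)), ms.float32)  # (batch, tokens, dim)
#   y = attn(x)                                         # same shape as the input: (2, 197, 768)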


class LayerScale(nn.Cell):
    """
    Layer scale helps ViT improve its training dynamics, allowing deeper, high-capacity image
    transformers to be trained and to benefit from the added depth.

    Args:
        dim (int): The output dimension of the attention or MLP layer.
        init_values (float): The scale factor. Default: 1e-5.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ls = LayerScale(768, 0.01)
    """
    def __init__(
        self,
        dim: int,
        init_values: float = 1e-5
    ):
        super(LayerScale, self).__init__()
        self.gamma = Parameter(initializer(init_values, dim))

    def construct(self, x):
        return self.gamma * x
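
# LayerScale usage sketch (illustrative values): ls = LayerScale(dim=768, init_values=1e-5); y = ls(x)
# scales each channel of a residual branch by a small learnable factor initialized to init_values.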


class Block(nn.Cell):
    """
    Transformer block implementation.

    Args:
        dim (int): The dimension of embedding.
        num_heads (int): The number of attention heads. Default: 8.
        mlp_ratio (float): The ratio used to scale the embedding dimension to obtain the hidden
            dimension of the MLP. Default: 4.0.
        qkv_bias (bool): Specifies whether the qkv projection uses a bias vector. Default: False.
        qk_norm (bool): Specifies whether to normalize q and k in the attention layer. Default: False.
        proj_drop (float): The drop rate of the dense layer output, in the range [0, 1). Default: 0.0.
        attn_drop (float): The drop rate of the attention map, in the range [0, 1). Default: 0.0.
        init_values (float, optional): Initial value for LayerScale. If None, LayerScale is disabled.
            Default: None.
        drop_path (float): The drop rate for drop path. Default: 0.0.
        act_layer (nn.Cell): Activation function used in the MLP. Default: nn.GELU.
        norm_layer (nn.Cell): Normalization layer applied before the attention and MLP sub-blocks.
            Default: nn.LayerNorm.
        mlp_layer (Callable): The MLP implementation. Default: Mlp.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> blk = Block(768, 12)
    """
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        mlp_ratio: float = 4.,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_drop: float = 0.,
        attn_drop: float = 0.,
        init_values: Optional[float] = None,
        drop_path: float = 0.,
        act_layer: nn.Cell = nn.GELU,
        norm_layer: nn.Cell = nn.LayerNorm,
        mlp_layer: Callable = Mlp,
    ):
        super(Block, self).__init__()
        self.norm1 = norm_layer((dim,))
        self.attn = Attention(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer((dim,))
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop
        )
        self.ls2 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def construct(self, x):
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x
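
# A minimal sketch of a single transformer block (illustrative values):
#   blk = Block(dim=768, num_heads=12, mlp_ratio=4., qkv_bias=True)
#   tokens = ms.Tensor(np.ones((2, 197, 768)), ms.float32)
#   out = blk(tokens)  # pre-norm attention and MLP, each with a residual connection; shape (2, 197, 768)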


class VisionTransformer(nn.Cell):
    """
    Vision Transformer (ViT) backbone.

    The input image is split into fixed-size patches and linearly embedded; a learnable class
    token and position embeddings are added, and the token sequence is encoded by a stack of
    transformer blocks. The encoded features are pooled ('token' or 'avg') and passed to a
    linear classification head.
    """
    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        global_pool: str = 'token',
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        drop_rate: float = 0.,
        pos_drop_rate: float = 0.,
        patch_drop_rate: float = 0.,
        proj_drop_rate: float = 0.,
        attn_drop_rate: float = 0.,
        drop_path_rate: float = 0.,
        weight_init: bool = True,
        init_values: Optional[float] = None,
        no_embed_class: bool = False,
        pre_norm: bool = False,
        fc_norm: Optional[bool] = None,
        dynamic_img_size: bool = False,
        dynamic_img_pad: bool = False,
        act_layer: nn.Cell = nn.GELU,
        embed_layer: Callable = PatchEmbed,
        norm_layer: nn.Cell = nn.LayerNorm,
        mlp_layer: Callable = Mlp,
        class_token: bool = True,
        block_fn: Callable = Block,
        num_classes: int = 1000,
    ):
        super(VisionTransformer, self).__init__()
        assert global_pool in ('', 'avg', 'token')
        assert class_token or global_pool != 'token'
        use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm

        self.global_pool = global_pool
        self.num_prefix_tokens = 1 if class_token else 0
        self.no_embed_class = no_embed_class
        self.dynamic_img_size = dynamic_img_size
        self.dynamic_img_pad = dynamic_img_pad

        embed_args = {}
        if dynamic_img_size:
            # flatten deferred until after pos embed
            embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
        elif dynamic_img_pad:
            embed_args.update(dict(output_fmt='NHWC'))

        self.patch_embed = embed_layer(
            image_size=image_size,
            patch_size=patch_size,
            in_chans=in_channels,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
            dynamic_img_pad=dynamic_img_pad,
            **embed_args,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = Parameter(initializer(TruncatedNormal(0.02), (1, 1, embed_dim))) if class_token else None
        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed = Parameter(initializer(TruncatedNormal(0.02), (1, embed_len, embed_dim)))
        self.pos_drop = Dropout(pos_drop_rate)
        if patch_drop_rate > 0:
            self.patch_drop = PatchDropout(
                patch_drop_rate,
                num_prefix_tokens=self.num_prefix_tokens,
            )
        else:
            self.patch_drop = nn.Identity()

        self.norm_pre = norm_layer((embed_dim,)) if pre_norm else nn.Identity()
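        # stochastic depth decay rule: the drop-path rate grows linearly from 0 to drop_path_rate across blocks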
        dpr = [x.item() for x in np.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.CellList([
            block_fn(
                dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
                attn_drop=attn_drop_rate, proj_drop=proj_drop_rate,
                mlp_ratio=mlp_ratio, drop_path=dpr[i], init_values=init_values,
                act_layer=act_layer, norm_layer=norm_layer, mlp_layer=mlp_layer,
            ) for i in range(depth)
        ])

        self.norm = norm_layer((embed_dim,)) if not use_fc_norm else nn.Identity()
        self.fc_norm = norm_layer((embed_dim,)) if use_fc_norm else nn.Identity()
        self.head_drop = Dropout(drop_rate)
        self.head = nn.Dense(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if weight_init:
            self._init_weights()

    def get_num_layers(self):
        return len(self.blocks)

    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def _init_weights(self):
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Dense):
                cell.weight.set_data(
                    initializer(TruncatedNormal(0.02), cell.weight.shape, cell.weight.dtype)
                )
                if cell.bias is not None:
                    cell.bias.set_data(
                        initializer('zeros', cell.bias.shape, cell.bias.dtype)
                    )
            elif isinstance(cell, nn.LayerNorm):
                cell.gamma.set_data(
                    initializer('ones', cell.gamma.shape, cell.gamma.dtype)
                )
                cell.beta.set_data(
                    initializer('zeros', cell.beta.shape, cell.beta.dtype)
                )
            elif isinstance(cell, nn.Conv2d):
                cell.weight.set_data(
                    initializer(HeUniform(), cell.weight.shape, cell.weight.dtype)
                )
                if cell.bias is not None:
                    cell.bias.set_data(
                        initializer("zeros", cell.bias.shape, cell.bias.dtype)
                    )

    def _pos_embed(self, x):
        if self.dynamic_img_size or self.dynamic_img_pad:
            # bhwc format
            B, H, W, C = x.shape
            pos_embed = resample_abs_pos_embed(
                self.pos_embed,
                (H, W),
                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
            )
            x = ops.reshape(x, (B, -1, C))
        else:
            pos_embed = self.pos_embed

        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + pos_embed
            if self.cls_token is not None:
                cls_tokens = ops.broadcast_to(self.cls_token, (x.shape[0], -1, -1))
                cls_tokens = cls_tokens.astype(x.dtype)
                x = ops.concat((cls_tokens, x), axis=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if self.cls_token is not None:
                cls_tokens = ops.broadcast_to(self.cls_token, (x.shape[0], -1, -1))
                cls_tokens = cls_tokens.astype(x.dtype)
                x = ops.concat((cls_tokens, x), axis=1)
            x = x + pos_embed

        return self.pos_drop(x)

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x

    def forward_head(self, x):
        if self.global_pool:
            x = x[:, self.num_prefix_tokens:].mean(axis=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.fc_norm(x)
        x = self.head_drop(x)
        x = self.head(x)
        return x

    def construct(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
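
# Building the backbone directly (a minimal sketch; the factory functions below are the usual entry point):
#   net = VisionTransformer(image_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
#   images = ms.Tensor(np.ones((1, 3, 224, 224)), ms.float32)
#   feats = net.forward_features(images)  # (1, 197, 768): class token + 14*14 patch tokens
#   logits = net(images)                  # (1, 1000) with the default classification head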


@register_model
def vit_b_16_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
    default_cfg = default_cfgs["vit_b_16_224"]
    model = VisionTransformer(
        image_size=224, patch_size=16, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12,
        num_classes=num_classes, **kwargs
    )
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)

    return model


@register_model
def vit_b_16_384(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
    default_cfg = default_cfgs["vit_b_16_384"]
    model = VisionTransformer(
        image_size=384, patch_size=16, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12,
        num_classes=num_classes, **kwargs
    )
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)

    return model


@register_model
def vit_l_16_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
    default_cfg = default_cfgs["vit_l_16_224"]
    model = VisionTransformer(
        image_size=224, patch_size=16, in_channels=in_channels, embed_dim=1024, depth=24, num_heads=16,
        num_classes=num_classes, **kwargs
    )
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)

    return model


@register_model
def vit_l_16_384(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
    default_cfg = default_cfgs["vit_l_16_384"]
    model = VisionTransformer(
        image_size=384, patch_size=16, in_channels=in_channels, embed_dim=1024, depth=24, num_heads=16,
        num_classes=num_classes, **kwargs
    )
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)

    return model


@register_model
def vit_b_32_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
    default_cfg = default_cfgs["vit_b_32_224"]
    model = VisionTransformer(
        image_size=224, patch_size=32, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12,
        num_classes=num_classes, **kwargs
    )
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)

    return model


@register_model
def vit_b_32_384(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
    default_cfg = default_cfgs["vit_b_32_384"]
    model = VisionTransformer(
        image_size=384, patch_size=32, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12,
        num_classes=num_classes, **kwargs
    )
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)

    return model


@register_model
def vit_l_32_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
    default_cfg = default_cfgs["vit_l_32_224"]
    model = VisionTransformer(
        image_size=224, patch_size=32, in_channels=in_channels, embed_dim=1024, depth=24, num_heads=16,
        num_classes=num_classes, **kwargs
    )
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)

    return model
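

# Typical usage through the registered factories (a minimal sketch; the import path assumes the
# standard mindcv package layout):
#   from mindcv.models import vit_b_32_224
#   model = vit_b_32_224(pretrained=True)  # loads the checkpoint URL recorded in default_cfgs
#   logits = model(ms.Tensor(np.ones((1, 3, 224, 224)), ms.float32))  # (1, 1000)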
